{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "db77d287-a072-41a9-9065-dd93e0052d39",
   "metadata": {},
   "source": [
    "# 基础层介绍"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a2128a94-4eca-4b3f-ab8b-cd13d6373d9f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-12T07:13:23.684949Z",
     "iopub.status.busy": "2022-08-12T07:13:23.684421Z",
     "iopub.status.idle": "2022-08-12T07:13:23.690606Z",
     "shell.execute_reply": "2022-08-12T07:13:23.689742Z",
     "shell.execute_reply.started": "2022-08-12T07:13:23.684895Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc50a886-86a9-4b62-b27e-64c8c381c58c",
   "metadata": {},
   "source": [
    "## RNN\n",
    "\n",
    "https://pytorch.org/docs/stable/generated/torch.nn.RNN.html#torch.nn.RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3e9c5291-dcde-42ab-b466-64a760f46fe4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-12T07:18:19.685648Z",
     "iopub.status.busy": "2022-08-12T07:18:19.685177Z",
     "iopub.status.idle": "2022-08-12T07:18:19.693496Z",
     "shell.execute_reply": "2022-08-12T07:18:19.692717Z",
     "shell.execute_reply.started": "2022-08-12T07:18:19.685591Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 10 input size\n",
    "# 20 hidden size\n",
    "# 2 layers\n",
    "rnn = nn.RNN(10, 20, 2, bidirectional=True)\n",
    "\n",
    "# 5 seqence length\n",
    "# 3 batch size\n",
    "# 10 input size\n",
    "input = torch.randn(5, 3, 10)\n",
    "output, hn = rnn(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c46c8123-ba0b-4511-85cd-70b4862d8861",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-12T07:18:20.685784Z",
     "iopub.status.busy": "2022-08-12T07:18:20.685209Z",
     "iopub.status.idle": "2022-08-12T07:18:20.692531Z",
     "shell.execute_reply": "2022-08-12T07:18:20.691789Z",
     "shell.execute_reply.started": "2022-08-12T07:18:20.685733Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 3, 40])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7dde556-6f4d-40fb-a28f-e3b8f6ce6f0c",
   "metadata": {},
   "source": [
    "## LSTM\n",
    "\n",
    "https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "cce8ed61-e2a4-440a-91df-1e085bc4ce6f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T03:19:52.329132Z",
     "iopub.status.busy": "2022-08-11T03:19:52.328569Z",
     "iopub.status.idle": "2022-08-11T03:19:52.337332Z",
     "shell.execute_reply": "2022-08-11T03:19:52.336686Z",
     "shell.execute_reply.started": "2022-08-11T03:19:52.329082Z"
    }
   },
   "outputs": [],
   "source": [
    "rnn = nn.LSTM(10, 20, 2)\n",
    "input = torch.randn(5, 3, 10)\n",
    "h0 = torch.randn(2, 3, 20)\n",
    "c0 = torch.randn(2, 3, 20)\n",
    "output, (hn, cn) = rnn(input, (h0, c0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "067aed05-330b-4e74-90d9-b1499676620d",
   "metadata": {},
   "source": [
    "## GRU\n",
    "\n",
    "https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2c04db7c-d252-47c7-b145-53b8f0288de3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T03:20:26.331512Z",
     "iopub.status.busy": "2022-08-11T03:20:26.331099Z",
     "iopub.status.idle": "2022-08-11T03:20:26.387099Z",
     "shell.execute_reply": "2022-08-11T03:20:26.385923Z",
     "shell.execute_reply.started": "2022-08-11T03:20:26.331466Z"
    }
   },
   "outputs": [],
   "source": [
    "rnn = nn.GRU(10, 20, 2)\n",
    "input = torch.randn(5, 3, 10)\n",
    "h0 = torch.randn(2, 3, 20)\n",
    "output, hn = rnn(input, h0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "957dd9d5-38fd-4ab6-9568-5b6b277119a8",
   "metadata": {},
   "source": [
    "## Embedding\n",
    "\n",
    "https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5180cfb4-9924-4df0-a8da-7c073d7427f5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-12T03:42:31.936798Z",
     "iopub.status.busy": "2022-08-12T03:42:31.936360Z",
     "iopub.status.idle": "2022-08-12T03:42:31.973403Z",
     "shell.execute_reply": "2022-08-12T03:42:31.972690Z",
     "shell.execute_reply.started": "2022-08-12T03:42:31.936767Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.5086, -0.0062,  0.3178],\n",
       "         [ 0.5990, -0.0073,  2.3685],\n",
       "         [ 0.3637,  1.6501,  0.7241],\n",
       "         [ 1.2083,  0.5309,  1.9538]],\n",
       "\n",
       "        [[ 0.3637,  1.6501,  0.7241],\n",
       "         [ 1.0144,  0.3747, -0.2003],\n",
       "         [ 0.5990, -0.0073,  2.3685],\n",
       "         [-0.6171,  0.6727,  0.5932]]], grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# an Embedding module containing 10 tensors of size 3\n",
    "embedding = nn.Embedding(10, 3)\n",
    "# a batch of 2 samples of 4 indices each\n",
    "input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])\n",
    "embedding(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f5fd3841-6a97-4989-95e4-09e518999a49",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T03:17:34.772274Z",
     "iopub.status.busy": "2022-08-11T03:17:34.771862Z",
     "iopub.status.idle": "2022-08-11T03:17:34.780661Z",
     "shell.execute_reply": "2022-08-11T03:17:34.780036Z",
     "shell.execute_reply.started": "2022-08-11T03:17:34.772228Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.0000,  0.0000,  0.0000],\n",
       "         [ 0.2401, -1.0441,  1.3407],\n",
       "         [ 0.0000,  0.0000,  0.0000],\n",
       "         [ 0.7517,  0.9654,  0.0865]]], grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# example with padding_idx\n",
    "embedding = nn.Embedding(10, 3, padding_idx=0)\n",
    "input = torch.LongTensor([[0,2,0,5]])\n",
    "embedding(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e024c0b3-1f6f-48ca-9805-4adfaf4f2b24",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T03:17:58.691274Z",
     "iopub.status.busy": "2022-08-11T03:17:58.690940Z",
     "iopub.status.idle": "2022-08-11T03:17:58.699149Z",
     "shell.execute_reply": "2022-08-11T03:17:58.698595Z",
     "shell.execute_reply.started": "2022-08-11T03:17:58.691253Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 1.0000,  1.0000,  1.0000],\n",
       "        [ 0.5622, -0.7949,  1.0729],\n",
       "        [ 0.4224, -0.1981,  0.0331]], requires_grad=True)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# example of changing `pad` vector\n",
    "padding_idx = 0\n",
    "embedding = nn.Embedding(3, 3, padding_idx=padding_idx)\n",
    "embedding.weight\n",
    "with torch.no_grad():\n",
    "    embedding.weight[padding_idx] = torch.ones(3)\n",
    "embedding.weight"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63ff58c7-7dd8-4145-bb39-e3c3d1b9af62",
   "metadata": {},
   "source": [
    "## EmbeddingBag\n",
    "\n",
    "https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag\n",
    "\n",
    "Computes sums or means of ‘bags’ of embeddings, without instantiating the intermediate embeddings. For bags of constant length, no per_sample_weights, no indices equal to padding_idx, and with 2D inputs, this class\n",
    "\n",
    "- with mode=\"sum\" is equivalent to Embedding followed by torch.sum(dim=1),\n",
    "- with mode=\"mean\" is equivalent to Embedding followed by torch.mean(dim=1),\n",
    "- with mode=\"max\" is equivalent to Embedding followed by torch.max(dim=1).\n",
    "\n",
    "input (Tensor) – Tensor containing bags of indices into the embedding matrix.\n",
    "\n",
    "offsets (Tensor, optional) – Only used when input is 1D. offsets determines the starting index position of each bag (sequence) in input.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b28a5731-3629-45b9-a0c0-424e6cf25bd1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T03:26:14.888980Z",
     "iopub.status.busy": "2022-08-11T03:26:14.888569Z",
     "iopub.status.idle": "2022-08-11T03:26:14.896736Z",
     "shell.execute_reply": "2022-08-11T03:26:14.896029Z",
     "shell.execute_reply.started": "2022-08-11T03:26:14.888932Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.0465,  2.6469, -0.3063],\n",
       "        [-0.2636,  2.8705, -2.7252]], grad_fn=<EmbeddingBagBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# an EmbeddingBag module containing 10 tensors of size 3\n",
    "embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')\n",
    "# a batch of 2 samples of 4 indices each\n",
    "input = torch.tensor([1,2,4,5,4,3,2,9], dtype=torch.long)\n",
    "offsets = torch.tensor([0,4], dtype=torch.long)\n",
    "embedding_sum(input, offsets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "050fab40-82e3-49d1-ad0c-8d9d55f77887",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-12T03:50:27.939273Z",
     "iopub.status.busy": "2022-08-12T03:50:27.938717Z",
     "iopub.status.idle": "2022-08-12T03:50:27.947668Z",
     "shell.execute_reply": "2022-08-12T03:50:27.947063Z",
     "shell.execute_reply.started": "2022-08-12T03:50:27.939225Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  0.0000,  0.0000],\n",
       "        [-1.1157,  1.3940, -0.6917],\n",
       "        [-0.2551,  1.1187,  1.2667]], grad_fn=<EmbeddingBagBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Example with padding_idx\n",
    "embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2)\n",
    "input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long)\n",
    "offsets = torch.tensor([0,3, 6], dtype=torch.long)\n",
    "embedding_sum(input, offsets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8d3224ab-fb57-4da0-b4f9-eb6af82f75d2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-12T03:50:28.932698Z",
     "iopub.status.busy": "2022-08-12T03:50:28.932142Z",
     "iopub.status.idle": "2022-08-12T03:50:28.940822Z",
     "shell.execute_reply": "2022-08-12T03:50:28.940395Z",
     "shell.execute_reply.started": "2022-08-12T03:50:28.932649Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 0.9081, -1.5726, -0.6032],\n",
       "        [-0.4929, -0.1607,  1.4178],\n",
       "        [ 0.0000,  0.0000,  0.0000],\n",
       "        [ 0.2250,  1.1902, -0.1070],\n",
       "        [-1.3407,  0.2037, -0.5847],\n",
       "        [ 0.6875, -0.3568, -0.7314],\n",
       "        [ 0.7148,  0.0956, -0.5155],\n",
       "        [-1.2188, -0.9206, -0.5927],\n",
       "        [ 0.6804, -0.7137, -0.3640],\n",
       "        [-0.2551,  1.1187,  1.2667]], requires_grad=True)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_sum.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "b62793e5-14b6-49fd-9792-cdc79918fcf6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T05:27:09.329946Z",
     "iopub.status.busy": "2022-08-11T05:27:09.328811Z",
     "iopub.status.idle": "2022-08-11T05:27:09.337069Z",
     "shell.execute_reply": "2022-08-11T05:27:09.336642Z",
     "shell.execute_reply.started": "2022-08-11T05:27:09.329868Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 0.], grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_sum.weight[[2, 2, 2, 2]].sum(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "dd56d1ce-9d53-4eb5-b7c1-26ac670c3fbb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T05:27:11.862348Z",
     "iopub.status.busy": "2022-08-11T05:27:11.861781Z",
     "iopub.status.idle": "2022-08-11T05:27:11.870821Z",
     "shell.execute_reply": "2022-08-11T05:27:11.870363Z",
     "shell.execute_reply.started": "2022-08-11T05:27:11.862296Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.3092, 1.6480, 0.8876], grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_sum.weight[[ 4, 3, 2, 9]].sum(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d1bc7c9-d25f-4f0d-aff2-1acddfda915f",
   "metadata": {},
   "source": [
    "# Tagging案例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dbcfc61b-d826-49dd-abaa-6d757b65a6ce",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T03:14:32.325214Z",
     "iopub.status.busy": "2022-08-11T03:14:32.324803Z",
     "iopub.status.idle": "2022-08-11T03:14:32.332564Z",
     "shell.execute_reply": "2022-08-11T03:14:32.331753Z",
     "shell.execute_reply.started": "2022-08-11T03:14:32.325167Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'The': 0, 'dog': 1, 'ate': 2, 'the': 3, 'apple': 4, 'Everybody': 5, 'read': 6, 'that': 7, 'book': 8}\n"
     ]
    }
   ],
   "source": [
    "def prepare_sequence(seq, to_ix):\n",
    "    idxs = [to_ix[w] for w in seq]\n",
    "    return torch.tensor(idxs, dtype=torch.long)\n",
    "\n",
    "\n",
    "training_data = [\n",
    "    # Tags are: DET - determiner; NN - noun; V - verb\n",
    "    # For example, the word \"The\" is a determiner\n",
    "    (\"The dog ate the apple\".split(), [\"DET\", \"NN\", \"V\", \"DET\", \"NN\"]),\n",
    "    (\"Everybody read that book\".split(), [\"NN\", \"V\", \"DET\", \"NN\"])\n",
    "]\n",
    "word_to_ix = {}\n",
    "# For each words-list (sentence) and tags-list in each tuple of training_data\n",
    "for sent, tags in training_data:\n",
    "    for word in sent:\n",
    "        if word not in word_to_ix:  # word has not been assigned an index yet\n",
    "            word_to_ix[word] = len(word_to_ix)  # Assign each word with a unique index\n",
    "print(word_to_ix)\n",
    "tag_to_ix = {\"DET\": 0, \"NN\": 1, \"V\": 2}  # Assign each tag with a unique index\n",
    "\n",
    "# These will usually be more like 32 or 64 dimensional.\n",
    "# We will keep them small, so we can see how the weights change as we train.\n",
    "EMBEDDING_DIM = 6\n",
    "HIDDEN_DIM = 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0530d5c6-b014-43cd-a394-773e442ecdd0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T03:14:41.801920Z",
     "iopub.status.busy": "2022-08-11T03:14:41.801345Z",
     "iopub.status.idle": "2022-08-11T03:14:41.808358Z",
     "shell.execute_reply": "2022-08-11T03:14:41.807697Z",
     "shell.execute_reply.started": "2022-08-11T03:14:41.801870Z"
    }
   },
   "outputs": [],
   "source": [
    "class LSTMTagger(nn.Module):\n",
    "\n",
    "    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):\n",
    "        super(LSTMTagger, self).__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "\n",
    "        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)\n",
    "\n",
    "        # The LSTM takes word embeddings as inputs, and outputs hidden states\n",
    "        # with dimensionality hidden_dim.\n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_dim)\n",
    "\n",
    "        # The linear layer that maps from hidden state space to tag space\n",
    "        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)\n",
    "\n",
    "    def forward(self, sentence):\n",
    "        embeds = self.word_embeddings(sentence)\n",
    "        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))\n",
    "        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))\n",
    "        tag_scores = F.log_softmax(tag_space, dim=1)\n",
    "        return tag_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "030f25d4-97b6-42f1-af47-79b5688f5610",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-08-11T03:14:47.329196Z",
     "iopub.status.busy": "2022-08-11T03:14:47.328782Z",
     "iopub.status.idle": "2022-08-11T03:14:48.323259Z",
     "shell.execute_reply": "2022-08-11T03:14:48.322610Z",
     "shell.execute_reply.started": "2022-08-11T03:14:47.329151Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.8365, -1.2553, -1.2666],\n",
      "        [-0.8419, -1.2320, -1.2824],\n",
      "        [-0.8342, -1.3037, -1.2233],\n",
      "        [-0.8600, -1.2047, -1.2836],\n",
      "        [-0.7945, -1.3815, -1.2141]])\n",
      "tensor([[-0.1163, -3.0192, -2.7979],\n",
      "        [-3.9386, -0.0217, -6.2021],\n",
      "        [-2.6727, -5.6979, -0.0752],\n",
      "        [-0.0501, -3.6948, -3.7306],\n",
      "        [-3.8031, -0.0245, -6.2411]])\n"
     ]
    }
   ],
   "source": [
    "model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))\n",
    "loss_function = nn.NLLLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.1)\n",
    "\n",
    "# See what the scores are before training\n",
    "# Note that element i,j of the output is the score for tag j for word i.\n",
    "# Here we don't need to train, so the code is wrapped in torch.no_grad()\n",
    "with torch.no_grad():\n",
    "    inputs = prepare_sequence(training_data[0][0], word_to_ix)\n",
    "    tag_scores = model(inputs)\n",
    "    print(tag_scores)\n",
    "\n",
    "for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data\n",
    "    for sentence, tags in training_data:\n",
    "        # Step 1. Remember that Pytorch accumulates gradients.\n",
    "        # We need to clear them out before each instance\n",
    "        model.zero_grad()\n",
    "\n",
    "        # Step 2. Get our inputs ready for the network, that is, turn them into\n",
    "        # Tensors of word indices.\n",
    "        sentence_in = prepare_sequence(sentence, word_to_ix)\n",
    "        targets = prepare_sequence(tags, tag_to_ix)\n",
    "\n",
    "        # Step 3. Run our forward pass.\n",
    "        tag_scores = model(sentence_in)\n",
    "\n",
    "        # Step 4. Compute the loss, gradients, and update the parameters by\n",
    "        #  calling optimizer.step()\n",
    "        loss = loss_function(tag_scores, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "# See what the scores are after training\n",
    "with torch.no_grad():\n",
    "    inputs = prepare_sequence(training_data[0][0], word_to_ix)\n",
    "    tag_scores = model(inputs)\n",
    "\n",
    "    # The sentence is \"the dog ate the apple\".  i,j corresponds to score for tag j\n",
    "    # for word i. The predicted tag is the maximum scoring tag.\n",
    "    # Here, we can see the predicted sequence below is 0 1 2 0 1\n",
    "    # since 0 is index of the maximum value of row 1,\n",
    "    # 1 is the index of maximum value of row 2, etc.\n",
    "    # Which is DET NOUN VERB DET NOUN, the correct sequence!\n",
    "    print(tag_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af599bf9-a8ed-4c97-8a16-2a50125d88b2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
